# Usage: python datasets/scripts/preprocess_shapenet.py --source path/to/source

import json
import numpy as np
import os
from tqdm import tqdm
import argparse

def open_cv_to_open_gl(cam_matrix):
    """Convert a camera transformation matrix between axis conventions.

    The matrix is right-multiplied by ``diag(1, -1, -1, 1)``, which negates
    its second and third columns and leaves the first and fourth untouched.
    The input is not modified; a new array is returned.

    Args:
        cam_matrix: Matrix with four columns (a 4x4 camera transform in this
            script).

    Returns:
        The product ``cam_matrix @ diag(1, -1, -1, 1)``.
    """
    axis_flip = np.diag([1, -1, -1, 1])
    return cam_matrix @ axis_flip


def open_gl_to_open_cv(cam_matrix):
    """Convert a camera transformation matrix back to the other axis convention.

    This simply delegates to ``open_cv_to_open_gl``: that conversion
    multiplies by ``diag(1, -1, -1, 1)``, and applying it twice restores the
    original matrix, so the same operation serves both directions.
    """
    converted = open_cv_to_open_gl(cam_matrix)
    return converted


def list_recursive(folderpath):
    """Return the paths of the entries directly inside ``folderpath``.

    Despite its name, this function does not descend into subdirectories:
    it lists only the immediate children (files and directories alike), each
    joined with ``folderpath``.

    ``os.listdir`` yields names in arbitrary, filesystem-dependent order, so
    the names are sorted to make the result (and anything derived from it,
    such as the frame order in the generated JSON) reproducible.

    Args:
        folderpath: Directory to list.

    Returns:
        A list of ``os.path.join(folderpath, name)`` paths, sorted by name.

    Raises:
        OSError: If ``folderpath`` cannot be listed (e.g. it does not exist).
    """
    return [os.path.join(folderpath, filename) for filename in sorted(os.listdir(folderpath))]


def _require_file(path):
    """Raise FileNotFoundError unless ``path`` is an existing regular file."""
    if not os.path.isfile(path):
        raise FileNotFoundError(path)


def _load_pose(pose_path):
    """Read a pose file holding 16 whitespace-separated numbers as a 4x4 array."""
    with open(pose_path, 'r') as f:
        # split() without an argument tolerates newlines, tabs, repeated
        # spaces and trailing whitespace, unlike split(' ').
        values = [float(n) for n in f.read().split()]
    return np.array(values).reshape(4, 4)


def _load_intrinsics(intrinsics_path):
    """Return ``(focal, cx, cy)`` parsed from the first line of an intrinsics file."""
    with open(intrinsics_path, 'r') as f:
        first_line = f.readline().split()
    return float(first_line[0]), float(first_line[1]), float(first_line[2])


if __name__ == '__main__':
    # For every scene directory directly under --source, read the per-image
    # poses and the shared intrinsics, then write two files into the scene
    # directory: transforms.json (poses passed through open_cv_to_open_gl)
    # and transforms_opencv.json (poses exactly as read from disk).
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", type=str, required=True,
                        help="Dataset root containing one directory per scene.")
    parser.add_argument("--max_images", type=int, default=None,
                        help="Accepted for compatibility; not currently applied by this script.")
    # Parse the real command line (a hard-coded debug argument list used to
    # override whatever the user passed).
    args = parser.parse_args()

    # Parse cameras
    dataset_path = args.source

    orig_img_size = 512  # cars_train has intrinsics corresponding to image size of 512 * 512
    img_size = 128
    downscale_factor = orig_img_size / img_size

    for scene_folder_path in tqdm(list_recursive(dataset_path)):
        if not os.path.isdir(scene_folder_path):
            continue

        rgb_paths = list_recursive(os.path.join(scene_folder_path, 'rgb'))

        # The intrinsics are shared by every image of a scene, so they are
        # parsed once per scene. They are only required when the scene
        # actually contains images, as before.
        camera = {}
        if rgb_paths:
            intrinsics_path = os.path.join(scene_folder_path, 'intrinsics.txt')
            _require_file(intrinsics_path)
            focal, cx, cy = _load_intrinsics(intrinsics_path)
            camera = {
                "fl_x": focal / downscale_factor,
                "fl_y": focal / downscale_factor,
                "cx": cx / downscale_factor,
                "cy": cy / downscale_factor,
                "w": img_size,
                "h": img_size,
            }

        transforms = {
            "camera_model": "OPENCV",
            "frames": []
        }

        transforms_opencv = {
            "frames": []
        }

        for rgb_path in rgb_paths:
            _require_file(rgb_path)
            relative_path = os.path.relpath(rgb_path, scene_folder_path)

            # Build the pose path from its components instead of using
            # str.replace on the whole path, which would also rewrite any
            # 'rgb' / 'png' substring in parent directories or the file stem.
            image_stem = os.path.splitext(os.path.basename(rgb_path))[0]
            pose_path = os.path.join(scene_folder_path, 'pose', image_stem + '.txt')
            _require_file(pose_path)

            pose_open_cv = _load_pose(pose_path)
            pose_open_gl = open_cv_to_open_gl(pose_open_cv)

            transforms["frames"].append({
                "file_path": relative_path,
                "transform_matrix": pose_open_gl.tolist(),
                **camera,
            })

            transforms_opencv["frames"].append({
                "file_path": relative_path,
                "transform_matrix": pose_open_cv.tolist(),
                **camera,
            })

        with open(os.path.join(scene_folder_path, 'transforms.json'), 'w') as outfile:
            json.dump(transforms, outfile, indent=4)

        with open(os.path.join(scene_folder_path, 'transforms_opencv.json'), 'w') as outfile:
            json.dump(transforms_opencv, outfile, indent=4)

